# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-unconditionally-use-set-output-shapes-on-import -tf-enable-shape-inference-on-import=false -tf-graph-as-function %s -o -  | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-unconditionally-use-set-output-shapes-on-import -tf-enable-shape-inference-on-import=true -tf-graph-as-function %s -o - | FileCheck %s

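# Verify that when use of the `_output_shapes` attribute is unconditionally
# enabled on import, the imported op's result type carries the shape declared
# in that attribute, both with and without shape inference on import.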

# Argument 0 of the graph: a DT_STRING value whose `_output_shapes` annotation
# declares a 1-D shape of size 1. It is the second input of "RaggedToTensor".
node {
  name: "_Arg"
  op: "_Arg"
  attr {
    key: "T"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "index"
    value {
      i: 0
    }
  }
  # Declared output shape: [1].
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
        }
      }
    }
  }
}
# Scalar DT_INT64 constant with value -1. It is the first input of
# "RaggedToTensor" (presumably the op's `shape` operand -- confirm against the
# RaggedTensorToTensor op definition).
node {
  name: "Const0"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT64
        # Empty tensor_shape: the value is a scalar.
        tensor_shape {
        }
        int64_val: -1
      }
    }
  }
  # Declared output shape: scalar (no dims).
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
        }
      }
    }
  }
}
# Scalar DT_STRING constant holding the empty string. It is the third input of
# "RaggedToTensor" (presumably the op's `default_value` operand -- confirm
# against the RaggedTensorToTensor op definition).
node {
  name: "Const1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        # Empty tensor_shape: the value is a scalar.
        tensor_shape {
        }
        string_val: ""
      }
    }
  }
  # Declared output shape: scalar (no dims).
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
        }
      }
    }
  }
}
# DT_INT64 constant vector with two elements. It is the fourth input of
# "RaggedToTensor", i.e. the single row-partition tensor that op declares
# (row_partition_types = ["ROW_SPLITS"], num_row_partition_tensors = 1).
node {
  name: "Const2"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT64
        tensor_shape {
          dim {
            size: 2
          }
        }
        # 16 raw bytes = two 8-byte integers; read as little-endian int64 they
        # are 0 and 8.
        tensor_content: "\x00\x00\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00"
      }
    }
  }
  # Declared output shape: [2].
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 2
          }
        }
      }
    }
  }
}
# The op under test. Its inputs are, in order: "Const0" (scalar int64 -1),
# "_Arg" (string argument), "Const1" (scalar empty string) and "Const2"
# (int64 vector used as the single ROW_SPLITS row-partition tensor).
node {
  name: "RaggedToTensor"
  op: "RaggedTensorToTensor"
  input: ["Const0", "_Arg", "Const1", "Const2"]
  # Element type of the string inputs and of the result. This key must appear
  # only once: `attr` is a map, so a repeated "T" entry is redundant.
  attr {
    key: "T"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "row_partition_types"
    value {
      list {
        s: ["ROW_SPLITS"]
      }
    }
  }
  attr {
    key: "Tshape"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "Tindex"
    value {
      type: DT_INT64
    }
  }
  # One row-partition tensor is supplied ("Const2"), matching the single entry
  # in row_partition_types above.
  attr {
    key: "num_row_partition_tensors"
    value {
      i: 1
    }
  }
  # Declared output shape: [1, 1]. The CHECK line at the end of this file
  # expects exactly this shape in the imported op's result type.
  attr {
    key: "_output_shapes"
    value {
      list {
        shape {
          dim {
            size: 1
          }
          dim {
            size: 1
          }
        }
      }
    }
  }
}
# Return value 0 (DT_STRING), fed by "RaggedToTensor". Returning the op's
# result keeps it from being pruned if passes are run after import.
node {
  name: "rets_0"
  op: "_Retval"
  input: "RaggedToTensor"
  attr {
    key: "T"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "index"
    value {
      i: 0
    }
  }
}

# CHECK: tf.RaggedTensorToTensor{{.+}}-> tensor<1x1x!tf_type.string>
